# ruff: noqa: E501
import numpy as np
import unittest
from dataclasses import replace
from tinygrad import Tensor, Context, Device, dtypes
from tinygrad.uop.ops import Ops
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program

# Default matrix dimension: create_gemm_model uses it for batch_size, in_size and out_size when none are given.
N = 512

def create_gemm_model(model_path:str, batch_size=N, in_size=N, out_size=N, bias=False):
  """Build a single-node ONNX Gemm graph with random float32 initializers and save it to model_path.

  The graph maps a float "input" of shape (batch_size, in_size) to a float "output" of shape
  (batch_size, out_size). The weight "W" is drawn with np.random.randn as (in_size, out_size);
  when bias is true a second initializer "B" of shape (out_size,) is added as the third Gemm input.

  Returns model_path so the call can be used inline where a path is expected.
  """
  import onnx
  from onnx import helper, numpy_helper, TensorProto

  # graph-level input and output descriptions
  graph_inputs = [helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch_size, in_size])]
  graph_outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch_size, out_size])]

  # the weight is always present, the bias initializer only on request
  node_inputs = ["input", "W"]
  initializers = [numpy_helper.from_array(np.random.randn(in_size, out_size).astype(np.float32), name="W")]
  if bias:
    node_inputs.append("B")
    initializers.append(numpy_helper.from_array(np.random.randn(out_size).astype(np.float32), name="B"))

  gemm = helper.make_node("Gemm", inputs=node_inputs, outputs=["output"], alpha=1.0, beta=1.0, transB=0)
  graph = helper.make_graph([gemm], "SingleGemmGraph", graph_inputs, graph_outputs, initializer=initializers)

  # ir_version and opset are pinned explicitly
  # TODO remove this once ORT supports 1.18.0
  model = helper.make_model(graph, producer_name="single_gemm_example", ir_version=10, opset_imports=[helper.make_opsetid("", 22)])
  onnx.save_model(model, model_path)
  return model_path

def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
  """Compile the last kernel in out's schedule with the given opts and run it run_count times.

  The program is built for the default device's renderer. When replace_src is given, everything
  before the "/* DSP boilerplate */" marker in the generated source is swapped for replace_src, and
  occurrences of the generated kernel name in the remaining text are renamed to "fxn".
  Each run is issued with wait=True. Nothing is returned.
  """
  last_item = out.schedule()[-1]
  program = get_program(last_item.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)
  if replace_src is not None:
    # the generated kernel name sits between the noinline marker and the opening parenthesis
    generated_name = program.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
    head = replace_src + "/* DSP boilerplate */"
    tail = program.src.split("/* DSP boilerplate */")[1].replace(generated_name, "fxn")
    program = replace(program, src=head + tail)
  buffers = [buf.ensure_allocated() for buf in last_item.bufs]
  exec_item = ExecItem(CompiledRunner(program), buffers, last_item.metadata)
  for _ in range(run_count):
    exec_item.run(wait=True)

def get_quantized_model(sz):
  """Create a random sz x sz Gemm ONNX model and statically quantize it with onnxruntime.

  The float model is written to "/tmp/test_in.onnx" via create_gemm_model, then quantize_static
  produces a QDQ-format model (uint8 activations, int8 weights, per-tensor, asymmetric activations)
  calibrated on random uniform inputs of shape (sz, sz).

  Returns the path of the quantized model, "/tmp/test_out.onnx".
  """
  from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
  class FakeDataReader(CalibrationDataReader):
    """Calibration reader that serves 99 random batches, then returns None to signal exhaustion."""
    def __init__(self): self.cnt = 0
    def get_next(self) -> dict:
      self.cnt += 1
      # >= rather than ==: once exhausted the reader must keep returning None if it is polled again
      if self.cnt >= 100: return None
      return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
  out_file = "/tmp/test_out.onnx"
  quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
                  FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
                  activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
                  extra_options={"ActivationSymmetric": False})
  return out_file

@unittest.skip("this is broken")
@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
class TestQuantizeOnnxCPU(unittest.TestCase):
  """CPU-only accumulator-dtype check on a quantized ONNX Gemm (currently skipped unconditionally)."""

  def test_quant_128(self, sz=128):
    """Schedule a quantized sz x sz Gemm under QUANTIZE=1 and require int DEFINE_REG uops.

    The second-to-last schedule item is lowered and every DEFINE_REG uop in its program
    must have dtypes.int as its scalar dtype. Skips when onnx is not importable.
    """
    try:
      import onnx # noqa: F401 # pylint: disable=unused-import
    except ImportError:
      raise unittest.SkipTest()
    from tinygrad.nn.onnx import OnnxRunner
    model_file = get_quantized_model(sz)
    runner = OnnxRunner(model_file)
    sample = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
    with Context(QUANTIZE=1):
      schedule = runner({"input":sample})["output"].schedule()
      lowered = lower_schedule_item(schedule[-2])
      define_regs = [uop for uop in lowered.prg.p.uops if uop.op is Ops.DEFINE_REG]
      for reg in define_regs:
        assert reg.dtype.scalar() is dtypes.int

@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
class TestQuantizeOnnx(unittest.TestCase):
  """DSP-only tests: a quantized ONNX Gemm plus pre-quantized conv/gemm/gemv kernels run with explicit opts."""

  @staticmethod
  def _upcast128_unroll4():
    """Return a fresh copy of the opt list shared by most tests: UPCAST axis 1 by 128, UNROLL axis 0 by 4."""
    return [Opt(op=OptOps.UPCAST, axis=1, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]

  def test_quant_128(self):
    """Run test_quant at size 128."""
    self.test_quant(128)

  def test_quant(self, sz=512):
    """Quantize a random sz x sz Gemm model and run it once through the benchmark ONNX loader."""
    from examples.benchmark_onnx import load_onnx_model
    # divide is ~1500-2000 without reduce_range, 750-900 with it
    model_file = get_quantized_model(sz)
    jitted_model, _ = load_onnx_model(model_file)
    sample = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
    jitted_model(input=sample)

  def test_prequant_conv2d_1x1(self):
    """Execute a uint8 1x1 conv2d (32 -> 64 channels on a 128x128 image) with the shared opts."""
    image = Tensor(np.random.uniform(0, 255, size=(1, 32, 128, 128)).astype(np.uint8))
    kernel = Tensor(np.random.uniform(0, 255, size=(64, 32, 1, 1)).astype(np.uint8))
    result = image.conv2d(kernel, dtype=image.dtype)
    sexec(result, self._upcast128_unroll4())

  def test_prequant_gemm(self):
    """Execute a 512x512 uint8 matmul, passing the input dtype as the matmul dtype."""
    dim = 512
    lhs = Tensor(np.random.uniform(0, 255, size=(dim,dim)).astype(np.uint8))
    rhs = Tensor(np.random.uniform(0, 255, size=(dim,dim)).astype(np.uint8))
    result = lhs.matmul(rhs, dtype=lhs.dtype)
    sexec(result, self._upcast128_unroll4())

  # TODO: this has to work
  def test_prequant_gemm_intacc_early(self, xi=np.int8, wi=np.int8):
    """Execute an int matmul divided by 1000 and cast to int8; output is not checked."""
    dim = 512
    lhs = Tensor(np.random.uniform(0, 255, size=(dim,dim)).astype(xi))
    rhs = Tensor(np.random.uniform(0, 255, size=(dim,dim)).astype(wi))
    # this divide is interesting and forces the accumulator to actually be an int
    scaled = lhs.cast("int").matmul(rhs.cast("int"))//1000
    result = scaled.cast("int8")
    sexec(result, self._upcast128_unroll4())

  def test_prequant_gemm_handcode(self):
    """Run the int-accumulator GEMM check with a hand-written Hexagon-intrinsics kernel substituted for the generated source."""
    handcoded_src = """typedef int int128 __attribute__((aligned(512),vector_size(512)));
    typedef int int32 __attribute__((aligned(128),vector_size(128)));
    typedef int int64 __attribute__((aligned(256),vector_size(256)));
    typedef unsigned char unsigned_char4 __attribute__((aligned(4),vector_size(4)));
    typedef signed char signed_char128 __attribute__((aligned(128),vector_size(128)));
    typedef unsigned char unsigned_char128 __attribute__((aligned(128),vector_size(128)));
    typedef unsigned char unsigned_char256 __attribute__((aligned(256),vector_size(256)));
    union V256 {
      unsigned_char256 vec256;
      struct {
        unsigned_char128 lo128;
        unsigned_char128 hi128;
      };
    };
    __attribute__((noinline)) void fxn(unsigned char* restrict __attribute__((align_value(128))) data0,
                                       unsigned char* restrict __attribute__((align_value(128))) data1,
                                       signed char* restrict __attribute__((align_value(128))) data2) {
      for (int ridx0 = 0; ridx0 < 512; ridx0++) {
        int alu0 = (ridx0<<9);
        for (int ridx1 = 0; ridx1 < 4; ridx1++) {
          int alu1 = (ridx1<<7);
          int32 acc0 = __builtin_HEXAGON_V6_vd0_128B();
          int32 acc1 = __builtin_HEXAGON_V6_vd0_128B();
          int32 acc2 = __builtin_HEXAGON_V6_vd0_128B();
          int32 acc3 = __builtin_HEXAGON_V6_vd0_128B();

          for (int ridx2 = 0; ridx2 < 128; ridx2++) {
            unsigned_char4 val0 = *((unsigned_char4*)((data1+(alu0+(ridx2<<2)))));
            int alu2 = (alu1+(ridx2<<11));
            signed_char128 x0 = *((signed_char128*)((data2+alu2)));
            signed_char128 x1 = *((signed_char128*)((data2+(alu2+512))));
            signed_char128 x2 = *((signed_char128*)((data2+(alu2+1024))));
            signed_char128 x3 = *((signed_char128*)((data2+(alu2+1536))));

            union V256 ss01;
            // ss01.lo128 = (x0[0], x1[0], x0[2], x1[2], x0[4], x1[4], ...)
            // ss01.hi128 = (x0[1], x1[1], x0[3], x1[3], x0[5], x1[5], ...)
            ss01.vec256 = __builtin_HEXAGON_V6_vshufoeb_128B(x1, x0);

            union V256 ss23;
            // ss23.lo128 = (x2[0], x3[0], x2[2], x3[2], x2[4], x3[4], ...)
            // ss23.hi128 = (x2[1], x3[1], x2[3], x3[3], x2[5], x3[5], ...)
            ss23.vec256 = __builtin_HEXAGON_V6_vshufoeb_128B(x3, x2);

            union V256 sslo;
            // sslo.lo128 = (x0[0], x1[0], x2[0], x3[0], x0[4], x1[4], ...)
            // sslo.hi128 = (x0[2], x1[2], x2[2], x3[2], x0[6], x1[6], ...)
            sslo.vec256 = __builtin_HEXAGON_V6_vdealvdd_128B(ss23.lo128, ss01.lo128, 2);

            union V256 sshi;
            // sshi.lo128 = (x0[1], x1[1], x2[1], x3[1], x0[5], x1[5], ...)
            // sshi.hi128 = (x0[3], x1[3], x2[3], x3[3], x0[7], x1[7], ...)
            sshi.vec256 = __builtin_HEXAGON_V6_vdealvdd_128B(ss23.hi128, ss01.hi128, 2);

            //unsigned_char128 w0 = (unsigned_char128){val0[0],val0[1],val0[2],val0[3],val0[0],val0[1],val0[2],val0[3],...
            unsigned_char128 w0 = __builtin_HEXAGON_V6_lvsplatw_128B(*((unsigned int*)&val0));

            acc0 = __builtin_HEXAGON_V6_vrmpybusv_acc_128B(acc0, w0, sslo.lo128);
            acc1 = __builtin_HEXAGON_V6_vrmpybusv_acc_128B(acc1, w0, sshi.lo128);
            acc2 = __builtin_HEXAGON_V6_vrmpybusv_acc_128B(acc2, w0, sslo.hi128);
            acc3 = __builtin_HEXAGON_V6_vrmpybusv_acc_128B(acc3, w0, sshi.hi128);
          }
          acc0 /= 1000;
          acc1 /= 1000;
          acc2 /= 1000;
          acc3 /= 1000;
          // ','.join([f"acc{j}[{i}]" for i in range(32) for j in range(4)])
          // acc0[0], acc0[1], acc0[2], ..... acc3[30], acc3[31]
          unsigned_char128 packed = __builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B(acc3, acc2),
                                                                           __builtin_HEXAGON_V6_vpackwh_sat_128B(acc1, acc0));
          packed = __builtin_HEXAGON_V6_vshuffb_128B(packed);
          packed = __builtin_HEXAGON_V6_vshuffb_128B(packed);
          // acc0[0], acc1[0], acc2[0], ..... acc2[31], acc3[31]
          *((unsigned_char128*)((data0+(alu0+alu1)))) = packed;
        }
      }
    }"""
    self.test_prequant_gemm_intacc(xi=np.uint8, wi=np.int8, replace_src=handcoded_src)

  def test_prequant_gemm_intacc_32(self):
    """Run the int-accumulator GEMM check at N=32 with its own opt list (two UPCASTs and an UNROLL)."""
    small_opts = [
      Opt(op=OptOps.UPCAST, axis=1, arg=0),
      Opt(op=OptOps.UPCAST, axis=0, arg=4),
      Opt(op=OptOps.UNROLL, axis=0, arg=0),
    ]
    self.test_prequant_gemm_intacc(xi=np.uint8, wi=np.int8, N=32, opts=small_opts)

  def test_prequant_gemm_intacc_128(self):
    """Run the int-accumulator GEMM check at N=128 (uint8 x int8)."""
    self.test_prequant_gemm_intacc(xi=np.uint8, wi=np.int8, N=128)

  def test_prequant_gemm_intacc_256(self):
    """Run the int-accumulator GEMM check at N=256 (uint8 x int8)."""
    self.test_prequant_gemm_intacc(xi=np.uint8, wi=np.int8, N=256)

  def test_prequant_gemm_intacc(self, xi=np.uint8, wi=np.uint8, replace_src=None, N=512, clip=True, opts=None):
    """Check an N x N int matmul, floor-divided by 1000 and cast down, against a numpy reference.

    xi / wi are the numpy dtypes of the two random operands. The result dtype is int8 when xi is
    np.int8, otherwise uint8. With clip set, both the tinygrad and numpy results are clipped to that
    dtype's range before the cast. replace_src and opts are forwarded to sexec; opts defaults to the
    shared UPCAST/UNROLL list. The kernel is run once, then both outputs are printed and compared exactly.
    """
    x_np = np.random.uniform(0, 255, size=(N,N)).astype(xi)
    X = Tensor(x_np).realize()
    w_np = np.random.uniform(0, 255, size=(N,N)).astype(wi)
    W = Tensor(w_np).realize()
    if xi == np.int8:
      tg_dtype = dtypes.int8
    else:
      tg_dtype = dtypes.uint8

    # tinygrad side
    out = X.int().matmul(W.int())//1000
    if clip:
      out = out.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
    out = out.cast(tg_dtype)
    if opts is None:
      opts = self._upcast128_unroll4()
    sexec(out, opts, replace_src, run_count=1)
    tout = out.numpy()

    # numpy reference
    mout = (x_np.astype(np.int32) @ w_np.astype(np.int32)) // 1000
    if clip:
      mout = mout.clip(dtypes.min(tg_dtype),dtypes.max(tg_dtype))
    mout = mout.astype(xi)

    print(tout)
    print(mout)
    np.testing.assert_equal(tout, mout)

  def test_prequant_gemm_intacc_wi(self):
    """Run the int-accumulator GEMM check with int8 weights and the default uint8 input."""
    self.test_prequant_gemm_intacc(wi=np.int8)

  def test_prequant_gemm_intacc_xiwi(self):
    """Run the int-accumulator GEMM check with int8 input and int8 weights."""
    self.test_prequant_gemm_intacc(xi=np.int8, wi=np.int8)

  def test_prequant_gemm_intacc_xiwi_noclip(self):
    """Run the int8 x int8 int-accumulator GEMM check without clipping before the cast."""
    self.test_prequant_gemm_intacc(xi=np.int8, wi=np.int8, clip=False)

  def test_prequant_gemv(self):
    """Execute a (1, 2048) by (2048, 2048) uint8 matmul with UPCAST on axis 0; output is not checked."""
    dim = 2048
    vec = Tensor(np.random.uniform(0, 255, size=(1,dim)).astype(np.uint8)).realize()
    mat = Tensor(np.random.uniform(0, 255, size=(dim,dim)).astype(np.uint8)).realize()
    # the input dtype is passed explicitly as the matmul dtype
    # (earlier variants: int-cast operands with `@`, and a plain `X @ W`)
    result = vec.matmul(mat, dtype=vec.dtype)
    gemv_opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128), Opt(op=OptOps.UNROLL, axis=0, arg=4)]
    sexec(result, gemv_opts)

# Run the tests through unittest's command-line runner when this file is executed as a script.
if __name__ == "__main__":
  unittest.main()
